Note: if you feel like your Pandas skills need a bit of a touch up, check this article out!
The dataset we're going to be using is one of the most realistic retail time-series datasets you'll find out there because, well, it's actual Walmart data. It was made available for a Kaggle competition that you can check out here.
The original data was in a "wide" format to make it smaller in memory, but that format doesn't work well with databases and you won't see it very often in the real world. The most notable changes are that I added a date column to replace the date identifier columns that were previously there, and I made the data smaller by subsetting to the state of Texas.
Another note: If you want to develop and test your code with a smaller dataset (which I'd probably recommend), set sampled in the cell below to True. All of the tests will still pass if your code is correct!
Let's get into it!
data_dir = '/kaggle/input/project-1-data/data'
sampled = False
path_suffix = '' if not sampled else '_sampled'
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from statsmodels.graphics.tsaplots import plot_pacf, plot_acf
from sklearn.metrics import mean_squared_error
Time-series data has to be collected from some real-world, data-generating process. That means that raw data comes in as a series of observations. Depending on your experience with time-series data, you may be used to data that looks like this:
| Date | Sales |
|---|---|
| 2022-01-01 | 23 |
| 2022-01-02 | 45 |
| 2022-01-03 | 12 |
| 2022-01-04 | 67 |
| 2022-01-05 | 89 |
But, if you're in retail, each of those "sales" probably came in as JSON from some point-of-sale system (i.e., a cash register) that looked something like this:
{
"timestamp": 2022-01-01 12:34:56,
"product_id": 5,
"store_id": 12,
"category_id": 36,
...
}
Usually, it's the job of a data engineer to collect all of these records and aggregate them into a nice, tabular format, but it's worth at least having an appreciation for how it's done. So, we're going to start from a mock version of a transactions table.
transactions = pd.read_csv(f'{data_dir}/transactions_data{path_suffix}.csv')
transactions.head()
| date | id | item_id | dept_id | cat_id | store_id | state_id | |
|---|---|---|---|---|---|---|---|
| 0 | 2013-01-01 13:41:03 | HOBBIES_1_004_TX_1_evaluation | HOBBIES_1_004 | HOBBIES_1 | HOBBIES | TX_1 | TX |
| 1 | 2013-01-01 07:30:52 | HOBBIES_1_004_TX_1_evaluation | HOBBIES_1_004 | HOBBIES_1 | HOBBIES | TX_1 | TX |
| 2 | 2013-01-01 11:17:38 | HOBBIES_1_004_TX_1_evaluation | HOBBIES_1_004 | HOBBIES_1 | HOBBIES | TX_1 | TX |
| 3 | 2013-01-01 20:18:59 | HOBBIES_1_025_TX_1_evaluation | HOBBIES_1_025 | HOBBIES_1 | HOBBIES | TX_1 | TX |
| 4 | 2013-01-01 21:36:09 | HOBBIES_1_028_TX_1_evaluation | HOBBIES_1_028 | HOBBIES_1 | HOBBIES | TX_1 | TX |
transactions.dtypes
date        object
id          object
item_id     object
dept_id     object
cat_id      object
store_id    object
state_id    object
dtype: object
You can see that this is a DataFrame where each row relates to purchases for an individual item. Here's a little data dictionary:
- date: the time at which an item was bought, down to the second
- id: the product ID. Each of these is an individual item at a specific store.
- item_id: an identifier for items, but not at the store level. You can use this to find the same item at different stores.
- dept_id: department ID. One level up from item_id in the hierarchy.
- cat_id: category ID. One level up from dept_id in the hierarchy.
- store_id: identifies the specific store where the product was bought.
- state_id: identifies the specific state where the product was bought.

date is supposed to be a datetime-like object, but you can see that when we loaded it from disk, it was loaded in as a string. Let's convert that column to datetime.
# QUESTION: Convert this column to a datetime object
transactions['date'] = pd.to_datetime(transactions['date'])
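As an aside: when you know the timestamp format up front, passing format= to pd.to_datetime skips pandas' per-value format inference, which can make parsing millions of rows noticeably faster. A small sketch:

```python
import pandas as pd

raw = pd.Series(['2013-01-01 13:41:03', '2013-01-01 07:30:52'])
# An explicit format string avoids format inference on every value
parsed = pd.to_datetime(raw, format='%Y-%m-%d %H:%M:%S')
```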
transactions.dtypes
date        datetime64[ns]
id          object
item_id     object
dept_id     object
cat_id      object
store_id    object
state_id    object
dtype: object
transactions
| date | id | item_id | dept_id | cat_id | store_id | state_id | |
|---|---|---|---|---|---|---|---|
| 0 | 2013-01-01 13:41:03 | HOBBIES_1_004_TX_1_evaluation | HOBBIES_1_004 | HOBBIES_1 | HOBBIES | TX_1 | TX |
| 1 | 2013-01-01 07:30:52 | HOBBIES_1_004_TX_1_evaluation | HOBBIES_1_004 | HOBBIES_1 | HOBBIES | TX_1 | TX |
| 2 | 2013-01-01 11:17:38 | HOBBIES_1_004_TX_1_evaluation | HOBBIES_1_004 | HOBBIES_1 | HOBBIES | TX_1 | TX |
| 3 | 2013-01-01 20:18:59 | HOBBIES_1_025_TX_1_evaluation | HOBBIES_1_025 | HOBBIES_1 | HOBBIES | TX_1 | TX |
| 4 | 2013-01-01 21:36:09 | HOBBIES_1_028_TX_1_evaluation | HOBBIES_1_028 | HOBBIES_1 | HOBBIES | TX_1 | TX |
| ... | ... | ... | ... | ... | ... | ... | ... |
| 12905710 | 2016-05-22 18:30:53 | FOODS_3_825_TX_3_evaluation | FOODS_3_825 | FOODS_3 | FOODS | TX_3 | TX |
| 12905711 | 2016-05-22 08:05:28 | FOODS_3_826_TX_3_evaluation | FOODS_3_826 | FOODS_3 | FOODS | TX_3 | TX |
| 12905712 | 2016-05-22 14:56:59 | FOODS_3_826_TX_3_evaluation | FOODS_3_826 | FOODS_3 | FOODS | TX_3 | TX |
| 12905713 | 2016-05-22 16:43:00 | FOODS_3_827_TX_3_evaluation | FOODS_3_827 | FOODS_3 | FOODS | TX_3 | TX |
| 12905714 | 2016-05-22 06:18:25 | FOODS_3_827_TX_3_evaluation | FOODS_3_827 | FOODS_3 | FOODS | TX_3 | TX |
12905715 rows × 7 columns
Our goal is to transform this dataset into one that's easy to analyze and train models on. For this project, we're going to work at the daily level, so our first step is to aggregate our transactions data up to the daily level.
To be more specific, this is what we want it to look like:
# This is a hefty table, so just peeking at the first 5 rows
pd.read_csv(f'{data_dir}/sales_data{path_suffix}.csv', nrows=5)
| date | id | item_id | dept_id | cat_id | store_id | state_id | sales | |
|---|---|---|---|---|---|---|---|---|
| 0 | 2013-01-01 | HOBBIES_1_001_TX_1_evaluation | HOBBIES_1_001 | HOBBIES_1 | HOBBIES | TX_1 | TX | 0 |
| 1 | 2013-01-01 | HOBBIES_1_002_TX_1_evaluation | HOBBIES_1_002 | HOBBIES_1 | HOBBIES | TX_1 | TX | 0 |
| 2 | 2013-01-01 | HOBBIES_1_003_TX_1_evaluation | HOBBIES_1_003 | HOBBIES_1 | HOBBIES | TX_1 | TX | 0 |
| 3 | 2013-01-01 | HOBBIES_1_004_TX_1_evaluation | HOBBIES_1_004 | HOBBIES_1 | HOBBIES | TX_1 | TX | 3 |
| 4 | 2013-01-01 | HOBBIES_1_005_TX_1_evaluation | HOBBIES_1_005 | HOBBIES_1 | HOBBIES | TX_1 | TX | 0 |
You can see that the sales column is really just a daily count of transactions for that particular id.
In the cell below, create a dataframe called data, which is the transactions dataframe aggregated to the daily level. It should look like the above, except you won't have zero sales days. Don't worry about order: the below test will handle that!
data = (
    transactions
    .assign(
        # Truncate each timestamp to its calendar day
        date=lambda df: pd.to_datetime(df.date.dt.date)
    )
    # Each row is one transaction, so counting rows per unique
    # combination of all columns gives daily sales
    .pipe(lambda df: df.groupby(list(df.columns))['id'].count())
    .reset_index(name='sales')
)
data
| date | id | item_id | dept_id | cat_id | store_id | state_id | sales | |
|---|---|---|---|---|---|---|---|---|
| 0 | 2013-01-01 | FOODS_1_004_TX_1_evaluation | FOODS_1_004 | FOODS_1 | FOODS | TX_1 | TX | 20 |
| 1 | 2013-01-01 | FOODS_1_004_TX_2_evaluation | FOODS_1_004 | FOODS_1 | FOODS | TX_2 | TX | 20 |
| 2 | 2013-01-01 | FOODS_1_004_TX_3_evaluation | FOODS_1_004 | FOODS_1 | FOODS | TX_3 | TX | 4 |
| 3 | 2013-01-01 | FOODS_1_005_TX_2_evaluation | FOODS_1_005 | FOODS_1 | FOODS | TX_2 | TX | 1 |
| 4 | 2013-01-01 | FOODS_1_009_TX_2_evaluation | FOODS_1_009 | FOODS_1 | FOODS | TX_2 | TX | 3 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 3895933 | 2016-05-22 | HOUSEHOLD_2_511_TX_3_evaluation | HOUSEHOLD_2_511 | HOUSEHOLD_2 | HOUSEHOLD | TX_3 | TX | 4 |
| 3895934 | 2016-05-22 | HOUSEHOLD_2_513_TX_1_evaluation | HOUSEHOLD_2_513 | HOUSEHOLD_2 | HOUSEHOLD | TX_1 | TX | 2 |
| 3895935 | 2016-05-22 | HOUSEHOLD_2_514_TX_3_evaluation | HOUSEHOLD_2_514 | HOUSEHOLD_2 | HOUSEHOLD | TX_3 | TX | 1 |
| 3895936 | 2016-05-22 | HOUSEHOLD_2_516_TX_2_evaluation | HOUSEHOLD_2_516 | HOUSEHOLD_2 | HOUSEHOLD | TX_2 | TX | 1 |
| 3895937 | 2016-05-22 | HOUSEHOLD_2_516_TX_3_evaluation | HOUSEHOLD_2_516 | HOUSEHOLD_2 | HOUSEHOLD | TX_3 | TX | 2 |
3895938 rows × 8 columns
If the cell below runs without error, you did it right!
def test_sales_eq(data):
    assert (
        pd.read_csv(f'{data_dir}/sales_data{path_suffix}.csv', usecols=['date', 'id', 'sales'])
        .assign(date=lambda df: pd.to_datetime(df.date))
        .query('sales != 0')
        .merge(data, on=['date', 'id'], how='left', suffixes=('_actual', '_predicted'))
        .fillna(0)
        .assign(sales_error=lambda df: (df.sales_actual - df.sales_predicted).abs())
        .sales_error
        .sum() < 1e-6
    ), 'Your version of sales does not match the original sales data.'
    assert (
        pd.read_csv(f'{data_dir}/sales_data{path_suffix}.csv', usecols=['date', 'id', 'sales'])
        .query('sales != 0')
        .shape[0]
    ) == data.shape[0], 'Your dataframe has a different number of rows than the original sales data.'
test_sales_eq(data)
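For reference, counting rows per group can also be written with .size(), which doesn't need to single out a column to count. Here's a toy sketch with made-up data (column names match ours):

```python
import pandas as pd

tx = pd.DataFrame({
    'date': pd.to_datetime(['2013-01-01 13:41:03',
                            '2013-01-01 07:30:52',
                            '2013-01-02 08:00:00']),
    'id': ['A', 'A', 'B'],
})
daily = (
    tx.assign(date=lambda df: df.date.dt.normalize())  # truncate to midnight
      .groupby(['date', 'id'])
      .size()
      .reset_index(name='sales')
)
```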
Let's take a look at how our data is being stored in memory.
data.info(memory_usage='deep')
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 3895938 entries, 0 to 3895937
Data columns (total 8 columns):
 #   Column    Dtype
---  ------    -----
 0   date      datetime64[ns]
 1   id        object
 2   item_id   object
 3   dept_id   object
 4   cat_id    object
 5   store_id  object
 6   state_id  object
 7   sales     int64
dtypes: datetime64[ns](1), int64(1), object(6)
memory usage: 1.5 GB
1.5 GB of data for our purposes is certainly no joke. But how much of that is really necessary?
Most of our data is stored in the least memory efficient format for pandas: strings (objects). Let's fix that.
Hint: check out this page of the pandas documentation that talks about data types.
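To see why categorical dtypes help so much here, a quick toy comparison (the exact ratio depends on the strings and your pandas version):

```python
import pandas as pd

s_obj = pd.Series(['TX_1', 'TX_2', 'TX_3'] * 100_000)
s_cat = s_obj.astype('category')
# A categorical stores each unique string once, plus compact integer codes,
# so repeated strings get dramatically cheaper
ratio = s_obj.memory_usage(deep=True) / s_cat.memory_usage(deep=True)
```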
In the below cell, convert the data types of columns to reduce memory usage as much as possible.
data = (
    data
    .assign(
        id=lambda df: df.id.astype('category'),
        item_id=lambda df: df.item_id.astype('category'),
        dept_id=lambda df: df.dept_id.astype('category'),
        cat_id=lambda df: df.cat_id.astype('category'),
        store_id=lambda df: df.store_id.astype('category'),
        state_id=lambda df: df.state_id.astype('category'),
    )
)
data.info(memory_usage='deep')
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 3895938 entries, 0 to 3895937
Data columns (total 8 columns):
 #   Column    Dtype
---  ------    -----
 0   date      datetime64[ns]
 1   id        category
 2   item_id   category
 3   dept_id   category
 4   cat_id    category
 5   store_id  category
 6   state_id  category
 7   sales     int64
dtypes: category(6), datetime64[ns](1), int64(1)
memory usage: 90.4 MB
In my solution, I got the final DataFrame down to 90.4 MB, which is about 6% of the original size!
While we're at it, it's worth talking about the best way to store this data on disk. If we saved this as a CSV, it wouldn't maintain any of the data type modifications we just made. Pandas offers a bunch of options for saving DataFrames, but here are the two I'd recommend:
Parquet has basically become the industry standard for storing tabular data on disk. It's a columnar file format that automatically compresses your data (which it does really well) and will maintain any data types you use in Pandas, with only a couple exceptions.
Feather is also a columnar data format, but it optimizes heavily for read speed. Your files will be much bigger than Parquet's, but it's really useful when reading data as fast as possible is the priority.
data.to_parquet('sales_data_checkpoint.parquet')
data = pd.read_parquet('sales_data_checkpoint.parquet')
data.info(memory_usage='deep')
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 3895938 entries, 0 to 3895937
Data columns (total 8 columns):
 #   Column    Dtype
---  ------    -----
 0   date      datetime64[ns]
 1   id        category
 2   item_id   category
 3   dept_id   category
 4   cat_id    category
 5   store_id  category
 6   state_id  category
 7   sales     int64
dtypes: category(6), datetime64[ns](1), int64(1)
memory usage: 90.4 MB
On my local machine, loading our original CSV took ~8.7 seconds, while loading the Parquet file took only 0.1 seconds. And our data types were maintained! Nice!
There's one last modification we need to make to our data before it's ready to go. The way we converted transactions into sales was slightly problematic: now, when a product doesn't sell on a given day, it just isn't present in our data, rather than appearing as a zero.
That's an issue for our forecasting models, so let's fix it!
First, set your index to columns that the DataFrame is distinct on (date and id).
data = data.set_index(['date','id'])
Now, create a MultiIndex with all combinations of daily dates and ids using pd.MultiIndex.from_product and use it and .reindex() to fill the gaps in your data.
# Your code here
dates = data.index.get_level_values('date').unique()
ids = data.index.get_level_values('id').unique()
index_to_select = pd.MultiIndex.from_product([dates, ids], names=['date', 'id'])
data = (
    data
    .reindex(index_to_select)
    .sort_index()
)
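Here's the same from_product + reindex pattern on a toy frame, to make the mechanics concrete (names are illustrative):

```python
import pandas as pd

df = pd.DataFrame(
    {'sales': [3, 5]},
    index=pd.MultiIndex.from_tuples(
        [(pd.Timestamp('2013-01-01'), 'A'), (pd.Timestamp('2013-01-02'), 'B')],
        names=['date', 'id'],
    ),
)
# Build every (date, id) combination, then reindex onto it
full_index = pd.MultiIndex.from_product(
    [df.index.get_level_values('date').unique(),
     df.index.get_level_values('id').unique()],
    names=['date', 'id'],
)
# The two missing (date, id) combinations show up as NaN rows
df_full = df.reindex(full_index).sort_index()
```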
Finally, fill the resulting NaNs in your dataframe. Hint: it's tempting to use .groupby().fillna(method='ffill') (and backfilling), but unfortunately this method is quite slow on grouped data. I'd recommend manually recreating the categorical columns by splitting the id column on underscores. This cell could take over a minute to run depending on how you implement it!
df_full = pd.DataFrame()
cat_cols = ['item_id', 'dept_id', 'cat_id', 'store_id', 'state_id']
for i in data.index.get_level_values('id').unique():
    # .copy() so the assignments below don't trigger SettingWithCopyWarning
    df_aux = data.loc[pd.IndexSlice[:, i], :].copy()
    # Each id's metadata is constant, so back/forward filling recovers it
    df_aux[cat_cols] = df_aux[cat_cols].bfill().ffill()
    df_aux['sales'] = df_aux['sales'].fillna(0)
    df_full = pd.concat([df_aux, df_full], axis=0)
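For completeness, the approach the hint suggests (rebuilding the categorical columns by splitting id on underscores) can be fully vectorized. Here's a sketch on two example ids, relying on the id format shown earlier:

```python
import pandas as pd

ids = pd.Series(['HOBBIES_1_004_TX_1_evaluation', 'FOODS_3_825_TX_3_evaluation'])
parts = ids.str.split('_')
meta = pd.DataFrame({
    'item_id': parts.str[:3].str.join('_'),    # e.g. HOBBIES_1_004
    'dept_id': parts.str[:2].str.join('_'),    # e.g. HOBBIES_1
    'cat_id': parts.str[0],                    # e.g. HOBBIES
    'store_id': parts.str[3:5].str.join('_'),  # e.g. TX_1
    'state_id': parts.str[3],                  # e.g. TX
})
```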
df_full.dtypes
item_id     category
dept_id     category
cat_id      category
store_id    category
state_id    category
sales       float64
dtype: object
data = df_full.copy()
data.isna().sum()
item_id     0
dept_id     0
cat_id      0
store_id    0
state_id    0
sales       0
dtype: int64
print('Sales sum in original data: {}'.format(pd.read_csv(f'{data_dir}/sales_data{path_suffix}.csv', usecols=['date', 'id', 'sales'])['sales'].sum()))
print('Sales sum in my data: {}'.format(int(data.sales.sum())))
Sales sum in original data: 12905715
Sales sum in my data: 12905715
def test_sales_eq(data):
    assert (
        pd.read_csv(f'{data_dir}/sales_data{path_suffix}.csv', usecols=['date', 'id', 'sales'])
        .assign(date=lambda df: pd.to_datetime(df.date))
        .merge(data, on=['date', 'id'], how='left', suffixes=('_actual', '_predicted'))
        .fillna({'sales_actual': 0, 'sales_predicted': 0})
        .assign(sales_error=lambda df: (df.sales_actual - df.sales_predicted).abs())
        .sales_error
        .sum() < 1e-6
    ), 'Your version of sales does not match the original sales data.'
test_sales_eq(data)
Exploratory data analysis is crucial for building the best models.
Before you start this section, though, I would highly recommend that you set the index of your DataFrame to both the date and id fields (our DataFrame has one row for each date/id combo). It's up to you, but it's good practice!
For this section, find 3-5 insights about the data that you feel are helpful for building models. Specifically, we'll be building models at the date/dept_id level (i.e., a forecast for FOODS_1 on 2011-02-01, 2011-02-02, etc., a forecast for HOBBIES_1 on 2011-02-01, 2011-02-02, etc.)
The only required analysis is an autocorrelation analysis. Other than that, anything goes! Be creative!
Here's an example of plotting the category-level sales for FOODS_1 to get you started:
(
    data
    .groupby(['date', 'dept_id'])
    .sales
    .sum()
    [:, 'FOODS_1']
    .plot()
)
departments = list(data.dept_id.unique())
categories = list(data.cat_id.unique())
dept_sales = data.groupby(['date', 'dept_id']).sales.sum().unstack('dept_id').fillna(0)
cat_sales = data.groupby(['date', 'cat_id']).sales.sum().unstack('cat_id').fillna(0)
sns.set_style('whitegrid')
for cat in categories:
    fig, ax = plt.subplots(figsize=(12, 6))
    sns.lineplot(data=cat_sales, x=cat_sales.index, y=cat,
                 color='lightsteelblue', ax=ax, label='y')
    sns.lineplot(data=cat_sales.rolling(30, min_periods=1).mean(),
                 x=cat_sales.index, y=cat, color='indianred',
                 ax=ax, label='30-day MA')
    sns.lineplot(data=cat_sales.rolling(365, min_periods=1).mean(),
                 x=cat_sales.index, y=cat, color='darkblue',
                 ax=ax, label='365-day MA')
    ax.set_title('Category: {}'.format(cat))
sns.set_style('whitegrid')
for dep in departments:
    fig, ax = plt.subplots(figsize=(12, 6))
    sns.lineplot(data=dept_sales, x=dept_sales.index, y=dep,
                 color='lightsteelblue', ax=ax, label='y')
    sns.lineplot(data=dept_sales.rolling(30, min_periods=1).mean(),
                 x=dept_sales.index, y=dep, color='indianred',
                 ax=ax, label='30-day MA')
    sns.lineplot(data=dept_sales.rolling(365, min_periods=1).mean(),
                 x=dept_sales.index, y=dep, color='darkblue',
                 ax=ax, label='365-day MA')
    ax.set_title('Department: {}'.format(dep))
from prophet import Prophet
df_prophet_dept_id = pd.melt(
    dept_sales.reset_index(),
    id_vars=['date'],
    var_name='dept_id',
    value_name='y',
).rename(columns={'date': 'ds'})
for dept in departments:
    model = Prophet(
        seasonality_mode='multiplicative',
        weekly_seasonality=True,
        yearly_seasonality=4,
        changepoint_prior_scale=0.05
    )
    model.add_seasonality(
        name='monthly',
        period=365.25 / 12,
        fourier_order=4,
        mode='multiplicative'
    )
    model.fit(df_prophet_dept_id[df_prophet_dept_id['dept_id'] == dept])
    future = model.make_future_dataframe(periods=0)
    forecast = model.predict(future)
    # Plot the forecast and its components
    print(f'Department: {dept}')
    fig1 = model.plot(forecast)
    fig2 = model.plot_components(forecast)
    plt.show()
Department: HOUSEHOLD_2
Department: HOUSEHOLD_1
Department: HOBBIES_2
Department: HOBBIES_1
Department: FOODS_3
Department: FOODS_2
Department: FOODS_1
dept_sales = data.groupby(['date', 'dept_id']).sales.sum().unstack('dept_id').fillna(0)
# Resample to get daily, monthly, and quarterly sales for each department
monthly_sales = dept_sales.resample('M').sum()
daily_sales = dept_sales.resample('D').sum()
quarterly_sales = dept_sales.resample('Q').sum()
# Plot daily autocorrelation for all departments
fig, axs = plt.subplots(len(departments), 1, figsize=(8, 15))
for i, dept in enumerate(departments):
    plot_acf(daily_sales[dept], ax=axs[i], lags=30)
    axs[i].set_title('Daily ACF (1 Month) - {} Department Sales'.format(dept))
plt.tight_layout()
plt.show()
# Plot daily partial autocorrelation for all departments
fig, axs = plt.subplots(len(departments), 1, figsize=(8, 15))
for i, dept in enumerate(departments):
    plot_pacf(daily_sales[dept], ax=axs[i], lags=30)
    axs[i].set_title('Daily PACF (1 Month) - {} Department Sales'.format(dept))
plt.tight_layout()
plt.show()
# Plot monthly autocorrelation for all departments
fig, axs = plt.subplots(len(departments), 1, figsize=(8, 15))
for i, dept in enumerate(departments):
    plot_acf(monthly_sales[dept], ax=axs[i], lags=16)
    axs[i].set_title('Monthly ACF - {} Department Sales'.format(dept))
plt.tight_layout()
plt.show()
# Plot monthly partial autocorrelation for all departments
fig, axs = plt.subplots(len(departments), 1, figsize=(8, 15))
for i, dept in enumerate(departments):
    plot_pacf(monthly_sales[dept], ax=axs[i], method='ols', lags=16)
    axs[i].set_title('Monthly PACF - {} Department Sales'.format(dept))
plt.tight_layout()
plt.show()
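Under the hood, the lag-k autocorrelation shown in these plots is just the correlation of the series with a k-step shifted copy of itself, which pandas exposes directly via Series.autocorr. A synthetic example with an exact 7-step (weekly-style) cycle:

```python
import numpy as np
import pandas as pd

# Synthetic series that repeats every 7 steps
s = pd.Series(np.sin(np.arange(200) * 2 * np.pi / 7))
acf_7 = s.autocorr(lag=7)  # shift by a full cycle: near-perfect correlation
acf_3 = s.autocorr(lag=3)  # mid-cycle shift: strongly negative correlation
```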
Finally, we can train some models!
We're going to use the statsforecast library, since it makes training statistical time-series models really easy. There are other great libraries (like darts, which is a more mature package), but I like statsforecast a bit more for these models. Eventually, we'll get to training our own models from scratch.
Here's what you need to do:
- Aggregate your data to the date/dept_id level so each date has 7 distinct records (one for each dept_id).
- Reshape the data into the format statsforecast likes.
- There are a lot of models available in statsforecast, so feel free to fit whatever you want, but focus on models like this one and this one since we discussed them. Their documentation has a quickstart to get going. I provided you with some helper code below to get started.
- If you finish early, try models from darts and see how they compare, etc.

# You'll see a lot of "WARNING: Retrying" outputs if you don't have internet enabled
# This happens because you haven't verified your phone number for your Kaggle profile.
# To do that, you'll need to exit the notebook, go to your profile, go to "account", and verify your phone number.
# Other than that, don't worry about any error outputs you see from this
! pip install statsforecast==1.5.0
Requirement already satisfied: pygments in /opt/conda/lib/python3.7/site-packages (from ipython->jupyter-dash>=0.4.2->plotly-resampler->statsforecast==1.5.0) (2.14.0)
Requirement already satisfied: jedi>=0.16 in /opt/conda/lib/python3.7/site-packages (from ipython->jupyter-dash>=0.4.2->plotly-resampler->statsforecast==1.5.0) (0.18.2)
Requirement already satisfied: pexpect>4.3 in /opt/conda/lib/python3.7/site-packages (from ipython->jupyter-dash>=0.4.2->plotly-resampler->statsforecast==1.5.0) (4.8.0)
Requirement already satisfied: idna<4,>=2.5 in /opt/conda/lib/python3.7/site-packages (from requests->jupyter-dash>=0.4.2->plotly-resampler->statsforecast==1.5.0) (3.4)
Requirement already satisfied: urllib3<1.27,>=1.21.1 in /opt/conda/lib/python3.7/site-packages (from requests->jupyter-dash>=0.4.2->plotly-resampler->statsforecast==1.5.0) (1.26.14)
Requirement already satisfied: charset-normalizer<4,>=2 in /opt/conda/lib/python3.7/site-packages (from requests->jupyter-dash>=0.4.2->plotly-resampler->statsforecast==1.5.0) (2.1.1)
Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/lib/python3.7/site-packages (from requests->jupyter-dash>=0.4.2->plotly-resampler->statsforecast==1.5.0) (2022.12.7)
Requirement already satisfied: parso<0.9.0,>=0.8.0 in /opt/conda/lib/python3.7/site-packages (from jedi>=0.16->ipython->jupyter-dash>=0.4.2->plotly-resampler->statsforecast==1.5.0) (0.8.3)
Requirement already satisfied: jupyter-core>=4.9.2 in /opt/conda/lib/python3.7/site-packages (from jupyter-client>=6.1.12->ipykernel->jupyter-dash>=0.4.2->plotly-resampler->statsforecast==1.5.0) (4.12.0)
Requirement already satisfied: entrypoints in /opt/conda/lib/python3.7/site-packages (from jupyter-client>=6.1.12->ipykernel->jupyter-dash>=0.4.2->plotly-resampler->statsforecast==1.5.0) (0.4)
Requirement already satisfied: ptyprocess>=0.5 in /opt/conda/lib/python3.7/site-packages (from pexpect>4.3->ipython->jupyter-dash>=0.4.2->plotly-resampler->statsforecast==1.5.0) (0.7.0)
Requirement already satisfied: wcwidth in /opt/conda/lib/python3.7/site-packages (from prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0->ipython->jupyter-dash>=0.4.2->plotly-resampler->statsforecast==1.5.0) (0.2.6)
Building wheels for collected packages: plotly-resampler, fugue-sql-antlr
Building wheel for plotly-resampler (pyproject.toml) ... done
Created wheel for plotly-resampler: filename=plotly_resampler-0.8.3.2-cp37-cp37m-manylinux_2_31_x86_64.whl size=75906 sha256=89bf7553e0a19c3653f40b77f7aabfcb0cf437949752bd2f30ceb9488fa242e8
Stored in directory: /root/.cache/pip/wheels/20/76/b6/bd15d35379014d2c71ddfbecc2d1e01713a1ad69a600717f85
Building wheel for fugue-sql-antlr (setup.py) ... - \ done
Created wheel for fugue-sql-antlr: filename=fugue_sql_antlr-0.1.5-py3-none-any.whl size=157611 sha256=41bfcf73ac8fda627a5a51ede0c7cf1d4bfd180cc6a163ab64070715e54f08bd
Stored in directory: /root/.cache/pip/wheels/5d/dc/65/0a25b69011abd4e7a198ad0e4aa5399ae919082f6959deba31
Successfully built plotly-resampler fugue-sql-antlr
Installing collected packages: trace-updater, sqlglot, dash-table, dash-html-components, dash-core-components, antlr4-python3-runtime, fs, ansi2html, triad, fugue-sql-antlr, dash, adagio, qpd, jupyter-dash, plotly-resampler, fugue, statsforecast
Successfully installed adagio-0.2.4 ansi2html-1.8.0 antlr4-python3-runtime-4.11.1 dash-2.9.0 dash-core-components-2.0.0 dash-html-components-2.0.0 dash-table-5.0.0 fs-2.4.16 fugue-0.8.1 fugue-sql-antlr-0.1.5 jupyter-dash-0.4.2 plotly-resampler-0.8.3.2 qpd-0.4.0 sqlglot-11.3.7 statsforecast-1.5.0 trace-updater-0.0.9 triad-0.8.3
WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv
from statsforecast import StatsForecast
from statsforecast.models import ARIMA, HoltWinters, AutoARIMA, AutoETS
# Aggregate data to date/dept_id level
train_data = data.groupby(['date', 'dept_id'])['sales'].sum().reset_index()
train_data.head()
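To see exactly what that aggregation does, here is the same groupby-sum pattern on a tiny hand-made frame (toy data standing in for the Walmart set):

```python
import pandas as pd

# Toy store-level rows: two dates, two departments, multiple rows per pair
data = pd.DataFrame({
    'date': pd.to_datetime(['2013-01-01'] * 3 + ['2013-01-02'] * 3),
    'dept_id': ['FOODS_1', 'FOODS_1', 'FOODS_2'] * 2,
    'sales': [10.0, 20.0, 5.0, 30.0, 40.0, 15.0],
})

# Collapse to one row per (date, dept_id), summing sales across stores
train_data = data.groupby(['date', 'dept_id'])['sales'].sum().reset_index()
print(train_data)  # 4 rows: one per date/department pair
```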
|   | date | dept_id | sales |
|---|---|---|---|
| 0 | 2013-01-01 | FOODS_1 | 727.0 |
| 1 | 2013-01-01 | FOODS_2 | 892.0 |
| 2 | 2013-01-01 | FOODS_3 | 4230.0 |
| 3 | 2013-01-01 | HOBBIES_1 | 412.0 |
| 4 | 2013-01-01 | HOBBIES_2 | 37.0 |
# StatsForecast expects the long format with columns unique_id, ds, and y
df = train_data.rename(columns={
    'dept_id': 'unique_id',
    'date': 'ds',
    'sales': 'y'
})

# Hold out the final 28 days (2016-04-24 onwards) for evaluation
train_df = df[df.ds < pd.Timestamp('2016-04-24')]
test_df = df[df.ds >= pd.Timestamp('2016-04-24')]
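The cutoff split above is just a boolean comparison against a timestamp; on a toy long-format frame (made-up values, same `unique_id`/`ds`/`y` schema):

```python
import pandas as pd

# Four days of one series in the long format StatsForecast expects
toy = pd.DataFrame({
    'unique_id': ['FOODS_1'] * 4,
    'ds': pd.to_datetime(['2016-04-22', '2016-04-23', '2016-04-24', '2016-04-25']),
    'y': [10.0, 12.0, 11.0, 13.0],
})

cutoff = pd.Timestamp('2016-04-24')
train = toy[toy.ds < cutoff]   # strictly before the cutoff
test = toy[toy.ds >= cutoff]   # cutoff date onwards

print(len(train), len(test))  # 2 2
```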
sf = StatsForecast(
    models=[
        # SARIMA(2, 1, 0)(0, 1, 1)[7]
        ARIMA(order=(2, 1, 0), seasonal_order=(0, 1, 1), season_length=7),
        # ETS model with weekly seasonality
        HoltWinters(season_length=7),
        AutoETS(),
        AutoARIMA()
    ],
    freq='D'
)
sf.fit(train_df)
StatsForecast(models=[ARIMA,HoltWinters,AutoETS,AutoARIMA])
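The (2, 1, 0)(0, 1, 1)[7] specification combines one ordinary difference with one seasonal difference at lag 7. A minimal pandas sketch of what that differencing removes (this illustrates the transformation, not statsforecast's internals):

```python
import pandas as pd
import numpy as np

# 21 days of a series with a linear trend plus a repeating weekly pattern
idx = pd.date_range('2016-01-01', periods=21, freq='D')
weekly = np.tile([0, 1, 2, 3, 4, 10, 12], 3).astype(float)
s = pd.Series(np.arange(21) + weekly, index=idx)

# d=1: first difference removes the trend
# D=1 with s=7: seasonal difference at lag 7 removes the weekly pattern
stationary = s.diff().diff(7).dropna()

print(stationary.values)  # this constructed series differences to all zeros
```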
forecast_df = sf.predict(h=28)
forecast_df.tail()
| unique_id | ds | ARIMA | HoltWinters | AutoETS | AutoARIMA |
|---|---|---|---|---|---|
| HOUSEHOLD_2 | 2016-05-17 | 472.271606 | 504.839935 | 570.158020 | 564.891357 |
| HOUSEHOLD_2 | 2016-05-18 | 474.971680 | 508.097931 | 569.995361 | 564.655457 |
| HOUSEHOLD_2 | 2016-05-19 | 496.143829 | 521.676636 | 569.839111 | 565.045105 |
| HOUSEHOLD_2 | 2016-05-20 | 521.533691 | 564.961609 | 569.689026 | 565.596802 |
| HOUSEHOLD_2 | 2016-05-21 | 650.241882 | 669.920959 | 569.544861 | 565.904175 |
from prophet import Prophet

forecast_prophet_full = pd.DataFrame()
for i, dept in enumerate(departments):
    model = Prophet(
        seasonality_mode='multiplicative',
        weekly_seasonality=True,
        yearly_seasonality=4,
        changepoint_prior_scale=0.05
    )
    model.add_seasonality(
        name='monthly',
        period=365.25 / 12,
        fourier_order=4,
        mode='multiplicative'
    )
    model.fit(train_df[train_df['unique_id'] == dept])
    future = model.make_future_dataframe(periods=28)
    forecast = model.predict(future)
    forecast['unique_id'] = dept
    # Keep only the 28 forecast rows and the columns we need
    forecast = forecast[-28:][['unique_id', 'ds', 'yhat']]
    forecast.rename(columns={'yhat': 'Prophet'}, inplace=True)
    forecast_prophet_full = pd.concat([forecast, forecast_prophet_full], axis=0)
forecast_prophet_full
|   | unique_id | ds | Prophet |
|---|---|---|---|
| 1209 | FOODS_1 | 2016-04-24 | 736.028754 |
| 1210 | FOODS_1 | 2016-04-25 | 604.104755 |
| 1211 | FOODS_1 | 2016-04-26 | 597.150699 |
| 1212 | FOODS_1 | 2016-04-27 | 614.607775 |
| 1213 | FOODS_1 | 2016-04-28 | 646.021058 |
| ... | ... | ... | ... |
| 1232 | HOUSEHOLD_2 | 2016-05-17 | 522.868086 |
| 1233 | HOUSEHOLD_2 | 2016-05-18 | 528.068774 |
| 1234 | HOUSEHOLD_2 | 2016-05-19 | 545.342317 |
| 1235 | HOUSEHOLD_2 | 2016-05-20 | 597.608276 |
| 1236 | HOUSEHOLD_2 | 2016-05-21 | 726.402751 |
196 rows × 3 columns
# Join the Prophet forecasts onto the statsforecast output by series and date
forecast_df = pd.merge(forecast_df.reset_index(), forecast_prophet_full, on=['unique_id', 'ds'])

train_df = train_df.set_index(['ds', 'unique_id'])
test_df = test_df.set_index(['ds', 'unique_id'])
forecast_df = forecast_df.set_index(['ds', 'unique_id'])
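The join above hinges on the shared `('unique_id', 'ds')` keys. A minimal sketch of the merge-then-index pattern, with made-up forecast values:

```python
import pandas as pd

# Hypothetical forecasts from two sources for the same series and dates
left = pd.DataFrame({
    'unique_id': ['FOODS_1', 'FOODS_1'],
    'ds': pd.to_datetime(['2016-04-24', '2016-04-25']),
    'ARIMA': [700.0, 610.0],
})
right = pd.DataFrame({
    'unique_id': ['FOODS_1', 'FOODS_1'],
    'ds': pd.to_datetime(['2016-04-24', '2016-04-25']),
    'Prophet': [736.0, 604.1],
})

# Inner join on the two key columns, then index by (ds, unique_id)
merged = pd.merge(left, right, on=['unique_id', 'ds']).set_index(['ds', 'unique_id'])
print(merged.columns.tolist())  # ['ARIMA', 'Prophet']
```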
sns.set_style('whitegrid')

test_min_date = test_df.index.get_level_values(0).min()

rmse_full = pd.DataFrame()
for i, dep in enumerate(departments):
    train_aux = train_df.loc[pd.IndexSlice['2016-02-01':, dep], :]
    test_aux = test_df.loc[pd.IndexSlice[:, dep], :]
    forecast_aux = forecast_df.loc[pd.IndexSlice[:, dep], :]

    fig, ax = plt.subplots(figsize=(12, 6))
    sns.lineplot(data=train_aux,
                 x=train_aux.index.get_level_values(0),
                 y='y',
                 color='lightsteelblue',
                 ax=ax,
                 label='train')
    sns.lineplot(data=test_aux,
                 x=test_aux.index.get_level_values(0),
                 y='y',
                 color='grey',
                 ax=ax,
                 label='test')
    ax.axvline(x=test_min_date, color='r', linestyle='--')

    for model in forecast_df.columns:
        sns.lineplot(data=forecast_aux,
                     x=forecast_aux.index.get_level_values(0),
                     y=model,
                     ax=ax,
                     label=model)
        # One RMSE row per (department, model) pair
        rmse = pd.DataFrame()
        rmse['RMSE'] = pd.Series(np.sqrt(mean_squared_error(test_aux['y'].head(28), forecast_aux[model])))
        rmse['model'] = model
        rmse['unique_id'] = dep
        rmse_full = pd.concat([rmse, rmse_full], axis=0)

    ax.set_title('Department: {}'.format(dep))
    ax.legend(loc='upper left', bbox_to_anchor=(1.0, 0.5))
rmse_full.pivot(columns='model', index='unique_id').round(2)
RMSE by department and model:

| unique_id | ARIMA | AutoARIMA | AutoETS | HoltWinters | Prophet |
|---|---|---|---|---|---|
| FOODS_1 | 270.97 | 197.61 | 218.16 | 242.26 | 205.27 |
| FOODS_2 | 274.38 | 223.52 | 265.65 | 231.26 | 171.11 |
| FOODS_3 | 555.53 | 790.28 | 808.01 | 604.91 | 550.88 |
| HOBBIES_1 | 115.97 | 146.36 | 154.71 | 113.52 | 109.86 |
| HOBBIES_2 | 40.67 | 39.58 | 40.15 | 41.55 | 45.17 |
| HOUSEHOLD_1 | 319.59 | 344.68 | 363.01 | 272.98 | 252.71 |
| HOUSEHOLD_2 | 73.63 | 77.30 | 83.08 | 54.63 | 45.87 |
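The RMSE-and-pivot pattern above boils down to the following sketch, with toy numbers and the sklearn call swapped for the equivalent NumPy expression:

```python
import numpy as np
import pandas as pd

# Toy actuals and two hypothetical model forecasts for one department
actual = np.array([100.0, 110.0, 105.0])
preds = {'ARIMA': np.array([98.0, 112.0, 103.0]),
         'Prophet': np.array([101.0, 109.0, 106.0])}

rows = []
for model, yhat in preds.items():
    # RMSE = sqrt of the mean squared error, the same quantity sklearn computes
    rmse = np.sqrt(np.mean((actual - yhat) ** 2))
    rows.append({'unique_id': 'FOODS_1', 'model': model, 'RMSE': rmse})

rmse_full = pd.DataFrame(rows)
# One row per unique_id, one column per model
table = rmse_full.pivot(columns='model', index='unique_id', values='RMSE').round(2)
print(table)  # ARIMA RMSE 2.0, Prophet RMSE 1.0
```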